#include "ast/typechecker.hpp"
#include "ast/visitor.hpp"
#include "ast/ast.hpp"
#include "type/type.hpp"
#include "runtime/universe.hpp"

#include <cassert>
#include <cstdio>
#include <cstdlib>

// Creates a type scope linked to the enclosing scope `prev`
// (NULL for the outermost one). The name -> Type* table is heap-allocated
// here and released by the destructor.
TypeScope::TypeScope(TypeScope* prev) {
    this->_type_def = new std::map<std::string, Type*>();
    this->_prev = prev;
}

// Frees the table owned by this scope. Only the map is released; the Type
// objects its entries point to are not deleted here.
TypeScope::~TypeScope() {
    this->_prev = NULL;

    // delete on a NULL pointer is a no-op, and destroying the map
    // destroys its entries, so no explicit clear()/NULL test is needed.
    delete this->_type_def;
    this->_type_def = NULL;
}

// Creates a variable scope linked to the enclosing scope `prev`
// (NULL for the outermost one). The name -> Type* symbol table is
// heap-allocated here and released by the destructor.
EncloseScope::EncloseScope(EncloseScope* prev) {
    this->_sym_table = new std::map<std::string, Type*>();
    this->_prev = prev;
}

// Frees the symbol table owned by this scope. Only the map is released; the
// Type objects its entries point to are not deleted here.
EncloseScope::~EncloseScope() {
    this->_prev = NULL;

    // delete on NULL is a no-op; destroying the map destroys its entries.
    delete this->_sym_table;
    this->_sym_table = NULL;
}

// Declares variable `name` with type `t` in this scope only (enclosing
// scopes are not consulted, so shadowing an outer variable is allowed).
// @Return true if `name` was not yet declared in this scope.
// @Return false, after printing a diagnostic, if it was; the existing
//         binding is left untouched.
bool EncloseScope::set_var_type(const char* name, Type* t) {
    // insert() never overwrites: .second is false when the key already exists.
    const bool inserted =
        _sym_table->insert(std::pair<std::string, Type*>(std::string(name), t)).second;

    if (!inserted) {
        printf("Redefination of variable %s\n", name);
    }
    return inserted;
}

// Looks `name` up in this scope and then in each enclosing scope, innermost
// first. Returns the first non-NULL type found; otherwise prints a
// diagnostic and returns NULL.
Type* EncloseScope::get_var_type(const char* name) {
    for (EncloseScope* level = this; level != NULL; level = level->prev()) {
        Type* found = level->find_variable(name);
        if (found != NULL) {
            return found;
        }
    }

    printf("Compilation: Variable %s is not declared.\n", name);
    return NULL;
}

// Looks `name` up in this scope only. Returns its type, or NULL if this
// scope has no entry for it.
Type* EncloseScope::find_variable(const char* name) {
    auto entry = _sym_table->find(name);
    return (entry == _sym_table->end()) ? NULL : entry->second;
}

// Removes the binding for `name` from this scope; a missing name is ignored.
void EncloseScope::remove(const char* name) {
    // erase-by-key is a no-op when the key is absent.
    _sym_table->erase(std::string(name));
}

// Looks the type name `name` up in this scope only. Returns the stored
// type, or NULL if this scope does not define it.
Type* TypeScope::find_local_type_def(const char* name) {
    auto entry = _type_def->find(name);
    return (entry == _type_def->end()) ? NULL : entry->second;
}

// Looks the type name `name` up in this scope and then in each enclosing
// scope, innermost first. Returns the first non-NULL definition, or NULL.
// Unlike EncloseScope::get_var_type, a miss prints nothing.
Type* TypeScope::find_type_def(const char* name) {
    for (TypeScope* level = this; level != NULL; level = level->prev()) {
        Type* found = level->find_local_type_def(name);
        if (found != NULL) {
            return found;
        }
    }
    return NULL;
}

// Applies the type function `func` to the actual type arguments `rargs`.
//
// Each actual argument is first resolved in THIS (the caller's) scope and
// then bound to the corresponding formal type variable in a fresh child
// scope; the function body is resolved in that child scope and returned.
//
// Terminates the process with exit(1) if `func` is NULL or not a type
// function, or if the number of actual arguments differs from the number
// of formals.
Type* TypeScope::apply_func(Type* func, TypeArgs* rargs) {
    if (!func || !func->is_type_func()) {
        printf("Error occurs.\n");
        exit(1);
    }

    TypeFunction* type_func = (TypeFunction*)func;
    TypeArgs* args = type_func->args();

    // A plain assert disappears in release builds, after which the loop
    // below would index past the end of the shorter argument list.
    if (args->length() != rargs->length()) {
        printf("Compilation: Type function expects %d type argument(s), but got %d.\n",
                (int)args->length(), (int)rargs->length());
        exit(1);
    }

    TypeScope type_scope(this);

    for (int i = 0; i < args->length(); i++) {
        TypeVar* tv = (TypeVar*)(args->arguments()[i]);
        // Resolve the actual argument in the caller's scope (this), not in
        // type_scope: otherwise a formal bound in an earlier iteration would
        // shadow a same-named type referenced by a later actual argument.
        type_scope.save_type_def(tv->name(), transform(rargs->arguments()[i]));
    }

    return type_scope.transform(type_func->body());
}

// Resolves type `t` against the type definitions visible from this scope.
//   - NULL and primitive types are returned unchanged.
//   - A type constructor (user-defined name) is replaced by its definition;
//     if none is visible a diagnostic is printed and NULL is returned.
//   - An arrow type is rebuilt as a new ArrowType from its resolved parts.
//   - A type-function application is resolved via apply_func().
//   - Anything else is returned unchanged.
Type* TypeScope::transform(Type* t) {
    if (t == NULL || t->is_primitive()) {
        return t;
    }

    if (t->is_type_con()) {
        const char* type_name = static_cast<UserDefType*>(t)->name();
        Type* resolved = find_type_def(type_name);
        if (resolved == NULL) {
            printf("Compliaton: Undefined type:'%s'\n", type_name);
        }
        return resolved;
    }

    if (t->is_arrow()) {
        ArrowType* arrow = dynamic_cast<ArrowType*>(t);
        assert(arrow != NULL);
        // Both parts are resolved inside one constructor call, exactly as
        // before, so the (unspecified) evaluation order is unchanged.
        return new ArrowType(transform(arrow->src()), transform(arrow->dst()));
    }

    if (t->is_type_apply()) {
        TypeFunctionApply* application = static_cast<TypeFunctionApply*>(t);
        return apply_func(transform(application->func()), application->rargs());
    }

    return t;
}

// Defines type name `name` as `t` in this scope only.
// Returns true on success; returns false, after printing a diagnostic, if
// this scope already defines `name` (the existing definition is kept).
bool TypeScope::save_type_def(const char* name, Type* t) {
    // insert() never overwrites: .second is false when the key already exists.
    const bool inserted =
        _type_def->insert(std::pair<std::string, Type*>(std::string(name), t)).second;

    if (!inserted) {
        printf("Compliaton: Redefination of type '%s'\n", name);
    }
    return inserted;
}

// Starts with no current type and a successful status. The root variable
// scope is pre-populated with the built-ins `print` and `println`, both
// typed Any -> Int. The root type scope starts empty.
TypeChecker::TypeChecker() {
    _status = true;
    _type = NULL;

    _scope = new EncloseScope(NULL);

    // TODO: move built-in registration into a dedicated method.
    _scope->set_var_type("print",
            new ArrowType(AnyType::get_instance(), IntType::get_instance()));
    _scope->set_var_type("println",
            new ArrowType(AnyType::get_instance(), IntType::get_instance()));

    _type_scope = new TypeScope(NULL);
}

// Releases both root scopes and then asks the Universe to clear its type
// space.
TypeChecker::~TypeChecker() {
    _status = false;
    _type = NULL;

    delete _scope;
    _scope = NULL;

    delete _type_scope;
    _type_scope = NULL;

    Universe::clear_type_space();
}

// NOTE(review): the three bitwise visitors below are empty. They do not
// visit their operands and do not assign _type. After one of these nodes is
// visited, _type therefore still holds whatever the previously visited node
// left there. Callers such as visit(VarDefNode*) read _type right after
// accept() and would see that stale value. Confirm whether '|', '^' and '&'
// are meant to be type-checked (e.g. through check_arith).
void TypeChecker::visit(BitOrNode* n) {
}

void TypeChecker::visit(BitXorNode* n) {
}

void TypeChecker::visit(BitAndNode* n) {
}

// Shift operators. Both '<<' and '>>' are checked by check_arith under the
// shared OPERATOR_SHIFT rule, so they accept the same operand types and
// both leave a well-defined result in _type.
void TypeChecker::visit(LeftShiftNode* n) {
    check_arith(OPERATOR_SHIFT, n, "<<");
}

void TypeChecker::visit(RightShiftNode* n) {
    check_arith(OPERATOR_SHIFT, n, ">>");
}

// Type-checks every child of the list in order. _type is left as the last
// child's visit set it (unchanged for an empty list).
// The traversal lives here rather than in ListNode::accept because the
// visiting order is a decision of each Visitor, not of the node.
void TypeChecker::visit(ListNode* n) {
    for (auto& child : *n->node_list()) {
        child->accept(this);
    }
}

// Shared type check for binary arithmetic/shift operators.
//
// Visits the left then the right operand and applies these rules:
//   - Any op Any            -> Any.
//   - one side Any          -> that side takes the other side's type.
//   - String * Int          -> String (only for OPERATOR_MUL).
//   - otherwise both types must be equal and support `op_tp`
//     (Type::support_operator); the result is that type.
// `op_s` is the operator's spelling, used only in diagnostics.
// On failure a diagnostic is printed (at most one per operator), _status
// is cleared and _type is set to NULL.
void TypeChecker::check_arith(OperatorType op_tp, BinaryOp* op, const char* op_s) {
    op->left()->accept(this);
    Type* ltype = _type;
    op->right()->accept(this);
    Type* rtype = _type;

    // An operand produced no type (e.g. an undeclared variable, reported by
    // visit(VarNode*)). Dereferencing it below would crash.
    if (ltype == NULL || rtype == NULL) {
        _type = NULL;
        _status = false;
        return;
    }

    Type* any = AnyType::get_instance();
    bool is_left_any = ltype == any;
    bool is_right_any = rtype == any;

    if (is_left_any && is_right_any) {
        _type = any;
        return;
    }

    if (is_left_any) {
        ltype = rtype;
    }
    else if (is_right_any) {
        rtype = ltype;
    }

    // special rule 1: string * int
    if (op_tp == OPERATOR_MUL && ltype->equals(StringType::get_instance())
            && rtype->equals(IntType::get_instance())) {
        _type = StringType::get_instance();
        return;
    }

    if (!ltype->equals(rtype)) {
        printf("unsupported operand type(s) for %s: '%s' and '%s'\n",
                op_s, ltype->to_string(), rtype->to_string());
        _type = NULL;
        _status = false;
        return;
    }

    if (!ltype->support_operator(op_tp)) {
        printf("unsupported operand type(s) for %s: '%s'\n",
                op_s, ltype->to_string());
        _type = NULL;
        _status = false;
        return;
    }

    // Set the result explicitly so it is the concrete operand type whichever
    // side was Any (previously `Int op Any` left _type as Any).
    _type = ltype;
}

// Arithmetic operators: every rule lives in check_arith; each visitor only
// supplies the operator kind and its spelling for diagnostics.
// NOTE(review): an earlier comment said '+' supports only Int, String, Double
// and Char; the actual check is Type::support_operator(OPERATOR_ADD).
void TypeChecker::visit(AddNode* n) {
    const char* symbol = "+";
    check_arith(OPERATOR_ADD, n, symbol);
}

void TypeChecker::visit(SubNode* n) {
    const char* symbol = "-";
    check_arith(OPERATOR_SUB, n, symbol);
}

void TypeChecker::visit(MulNode* n) {
    const char* symbol = "*";
    check_arith(OPERATOR_MUL, n, symbol);
}

void TypeChecker::visit(DivNode* n) {
    const char* symbol = "/";
    check_arith(OPERATOR_DIV, n, symbol);
}

// Literals: each constant node always has the corresponding primitive type.
void TypeChecker::visit(ConstInt* n) {
    this->_type = IntType::get_instance();
}

void TypeChecker::visit(ConstBool* n) {
    this->_type = BoolType::get_instance();
}

void TypeChecker::visit(ConstString* n) {
    this->_type = StringType::get_instance();
}

void TypeChecker::visit(ConstChar* n) {
    this->_type = CharType::get_instance();
}

// Type-checks `left = right`: the right-hand type must equal the left-hand
// type. On success _type is the right-hand type. On failure _status is
// cleared and _type is set to NULL.
void TypeChecker::visit(AssignNode* n) {
    n->left()->accept(this);
    Type* ltype = _type;
    n->right()->accept(this);
    Type* rtype = _type;

    // One side produced no type (e.g. an undeclared variable, already
    // reported by visit(VarNode*)). Do not dereference it.
    if (ltype == NULL || rtype == NULL) {
        _status = false;
        _type = NULL;
        return;
    }

    if (!ltype->equals(rtype)) {
        printf("Compliation: Can not assign type '%s' to type '%s'.\n",
                rtype->to_string(), ltype->to_string());
        _status = false;
        _type = NULL;
    }
}

// A variable reference has the type it was declared with in the nearest
// enclosing scope. An undeclared name (reported by get_var_type) yields NULL
// and marks the check as failed.
void TypeChecker::visit(VarNode* n) {
    Type* declared = _scope->get_var_type(n->name());
    _type = declared;

    if (declared == NULL) {
        _status = false;
    }
}

// Type-checks a variable definition and declares the variable in the
// current enclosing scope. The accepted forms are:
//   var x;              -> error: neither a type nor an initializer.
//   var x : T;          -> declared with the resolved type T.
//   var x = expr;       -> declared with the resolved type of expr.
//   var x : T = expr;   -> expr's resolved type must equal resolved T.
// Any failure (unresolvable type, redefinition, type mismatch) clears
// _status and sets _type to NULL. On success with an initializer, _type is
// the initializer's resolved type.
//
// TypeScope::transform() prints its own message for an undefined type but
// does not touch _status, so each NULL it returns is turned into a failure
// here.
void TypeChecker::visit(VarDefNode* n) {
    // var x;
    if (n->init_value() == NULL && n->type() == NULL) {
        printf("Declare variable '%s' without type.\n", n->name());
        _type = NULL;
        _status = false;
        return;
    }

    // var x:Int;
    if (n->init_value() == NULL) {
        Type* declared = _type_scope->transform(n->type());
        if (declared == NULL || !_scope->set_var_type(n->name(), declared)) {
            _type = NULL;
            _status = false;
        }
        return;
    }

    // var x = 4 + 3;  or  var x : T = 4 + 3;
    n->init_value()->accept(this);
    // Error in the initializer, already reported by its visitor.
    if (_type == NULL) {
        assert(_status == false);
        return;
    }
    _type = _type_scope->transform(_type);
    if (_type == NULL) {
        _status = false;
        return;
    }

    // var x = 4 + 3;
    if (n->type() == NULL) {
        if (!_scope->set_var_type(n->name(), _type)) {
            _type = NULL;
            _status = false;
        }
        return;
    }

    Type* ntype = _type_scope->transform(n->type());
    if (ntype == NULL) {
        _type = NULL;
        _status = false;
        return;
    }

    // var x : String = 4 + 3;  -- a mismatch must fail the check.
    if (!ntype->equals(_type)) {
        printf("Can not assign type '%s' to '%s' with type '%s'.\n",
                _type->to_string(), n->name(), ntype->to_string());
        _type = NULL;
        _status = false;
        return;
    }

    // var x : Int = 4 + 3;
    if (!_scope->set_var_type(n->name(), _type)) {
        _type = NULL;
        _status = false;
    }
}

// Resolves the right-hand side of a type definition and records it under
// the defined name in the current type scope. A type definition yields no
// value, so _type is always NULL afterwards.
void TypeChecker::visit(TypeDefNode* n) {
    Type* solved = _type_scope->transform(n->def_type());

    if (solved == NULL) {
        // transform() printed an "Undefined type" message but does not touch
        // _status. Record the failure and do not store a NULL definition.
        _status = false;
    }
    else if (!_type_scope->save_type_def(n->name(), solved)) {
        _status = false;
    }

    _type = NULL;
}

// Type-checks a single-parameter lambda. The parameter is bound in a fresh
// EncloseScope pushed on top of _scope while the body is checked, and the
// lambda's type is ArrowType(param_type, type of the body).
void TypeChecker::visit(LambdaDef* n) {
    Type* param_type = _type_scope->transform(n->param()->type());
    // An annotated parameter type that fails to resolve was reported by
    // transform(), which does not touch _status itself.
    if (param_type == NULL && n->param()->type() != NULL) {
        _status = false;
    }

    _scope = new EncloseScope(_scope);
    _scope->set_var_type(n->param()->name(), param_type);

    // return type
    n->body()->accept(this);

    // Pop the lambda scope. Deleting it also discards the parameter binding,
    // so no separate remove() is needed.
    EncloseScope* inner = _scope;
    _scope = _scope->prev();
    delete inner;

    _type = new ArrowType(param_type, _type);
}

// Type-checks a call `f(arg)`.
//   - If f has type Any, nothing is known about its signature: the argument
//     is still type-checked and the call has type Any.
//   - Otherwise f must have an ArrowType; if not, the check fails with
//     _type = NULL.
//   - If the arrow's parameter type is Any, any argument is accepted;
//     otherwise the argument's type must equal it.
// The call's type is the arrow's result type (at->dst()). It is kept even
// after an argument error (with _status cleared) so that checking of the
// surrounding expression can continue.
void TypeChecker::visit(CallNode* n) {
    n->func_name()->accept(this);
    Type* callee_type = _type;

    // Do not check Any type, and keep _type as Any. The argument expression
    // is still checked so errors inside it are reported.
    if (callee_type == AnyType::get_instance()) {
        n->param()->accept(this);
        _type = AnyType::get_instance();
        return;
    }

    ArrowType* at = dynamic_cast<ArrowType*>(callee_type);
    if (at == NULL) {
        printf("Compliation: Can not call an unappliable object.\n");
        _status = false;
        _type = NULL;
        return;
    }

    n->param()->accept(this);
    Type* arg_type = _type;

    // You can pass any value to 'AnyType' parameter.
    if (at->src() != AnyType::get_instance()) {
        if (arg_type == NULL) {
            // The argument produced no type; never dereference it.
            _status = false;
        }
        else if (!arg_type->equals(at->src())) {
            printf("Compliation: Expect '%s', but get '%s'.\n",
                    at->src()->to_string(), arg_type->to_string());
            _status = false;
        }
    }

    // Previously an Any-parameter call returned early and left the arrow
    // type itself in _type. The call's type is the function's result type.
    _type = at->dst();
}

// print / println statements: the printed expression is type-checked, and
// the statement itself yields no type (_type is reset to NULL).
void TypeChecker::visit(PrintNode* n) {
    Node_accept_and_forget: {
        n->body()->accept(this);
    }
    _type = NULL;
}

void TypeChecker::visit(PrintlnNode* n) {
    Node_accept_and_forget: {
        n->body()->accept(this);
    }
    _type = NULL;
}

// Type-checks a comparison. Both operands are visited, and their types must
// be equal, mirroring the same-type rule check_arith applies to arithmetic.
// An operand of type Any is compatible with anything. An operand that
// produced no type (NULL, already reported by its own visitor) is not
// re-reported. The comparison itself always has type Bool.
void TypeChecker::visit(CmpNode* n) {
    n->left()->accept(this);
    Type* ltype = _type;
    n->right()->accept(this);
    Type* rtype = _type;

    Type* any = AnyType::get_instance();
    bool comparable = ltype == NULL || rtype == NULL
        || ltype == any || rtype == any
        || ltype->equals(rtype);

    if (!comparable) {
        printf("Compilation: Can not compare type '%s' with type '%s'.\n",
                ltype->to_string(), rtype->to_string());
        _status = false;
    }

    _type = BoolType::get_instance();
}

// Type-checks `if (cond) then else`. The condition must be Bool, and both
// branches must have the same type, which becomes the type of the whole
// expression. On a branch mismatch _status is cleared and _type is NULL.
void TypeChecker::visit(IfNode* n) {
    n->cond()->accept(this);
    if (_type != BoolType::get_instance()) {
        printf("If condition's type is not bool\n");
        _status = false;
    }

    n->then_block()->accept(this);
    Type* then_type = _type;
    n->else_block()->accept(this);
    Type* else_type = _type;

    // A branch may leave _type NULL (e.g. after an error, or when it ends in
    // a print statement, whose visitor sets _type to NULL). Two NULL branches
    // agree; never call equals() on NULL.
    bool same_type = (then_type == else_type)
        || (then_type != NULL && else_type != NULL && then_type->equals(else_type));

    if (!same_type) {
        printf("Compilation: Different types in then block and else block\n");
        _status = false;
        _type = NULL;
    }
}

void TypeChecker::visit(LogicOrNode* n) {
    n->left()->accept(this);
    if (_type != BoolType::get_instance()) {
        printf("'||' does not take non-bool value as parameter.\n");
        _status = false;
    }

    n->right()->accept(this);
    if (_type != BoolType::get_instance()) {
        printf("'||' does not take non-bool value as parameter.\n");
        _status = false;
    }

    _type = BoolType::get_instance();
}

void TypeChecker::visit(LogicAndNode* n) {
    n->left()->accept(this);
    if (_type != BoolType::get_instance()) {
        printf("'&&' does not take non-bool value as parameter.\n");
        _status = false;
    }

    n->right()->accept(this);
    if (_type != BoolType::get_instance()) {
        printf("'&&' does not take non-bool value as parameter.\n");
        _status = false;
    }

    _type = BoolType::get_instance();
}

void TypeChecker::visit(LogicNotNode* n) {
    n->value()->accept(this);
    if (_type != BoolType::get_instance()) {
        printf("'!' does not take non-bool value as parameter.\n");
        _status = false;
    }

    _type = BoolType::get_instance();
}

